""""
Some common functions for Reinforcement Learning calculations.

Author: Meng Zhang
Date: January 2024
Disclaimer: adapted from the analysis code https://doi.org/10.4121/22153898.v1
"""

import pandas as pd
import numpy as np
import math
from copy import deepcopy
import random
def policy_evaluation(observation_space_size, discount_factor, trans_func,
                      reward_func, policy, q_vals,
                      update_tolerance=0.000000001,
                      reward_dep_next_state=False):
    """
    Returns the q-values for a specific policy.

    Repeatedly applies the Bellman expectation backup for ``policy`` to every
    state-action pair. Each sweep only reads Q-values from the previous
    sweep. Sweeps stop once the largest change of a Q-value within one sweep
    is at most ``update_tolerance``. At least one sweep is always performed.

    Args:
        observation_space_size: number of observations (states) to update
        discount_factor: discount factor of MDP
        trans_func: transition function, indexed as trans_func[s, a, s_prime]
        reward_func: reward function, indexed as reward_func[s, a], or as
            reward_func[s, a, s_prime] when reward_dep_next_state is True
        policy: policy to compute q-values for; int(policy[s]) is the action
            taken in state s
        q_vals: initial Q-value table indexed as q_vals[s, a] (warm start);
            it is not modified
        update_tolerance: precision (stopping threshold on the max. change
            of a Q-value in one sweep)
        reward_dep_next_state (boolean, default: False): whether reward depends also on the next state

    Returns:
        numpy float array of Q-values with the same shape as q_vals
    """
    # Work on a float copy. This leaves the caller's table untouched and
    # prevents an integer-typed table from silently truncating the Q-values.
    q_vals = np.array(q_vals, dtype=float)
    num_act = len(trans_func[0])
    states = range(observation_space_size)

    while True:
        q_vals_new = q_vals.copy()
        # Value of following the policy from each next state. It comes from
        # the previous sweep and is constant within this sweep.
        v_policy = [q_vals[s_prime, int(policy[s_prime])] for s_prime in states]
        update = 0  # max. change in Q-value in this sweep
        for s in states:
            for a in range(num_act):
                if reward_dep_next_state:
                    # reward does also depend on the next state
                    q_vals_new[s, a] = sum([trans_func[s, a, s_prime] * (
                        discount_factor * v_policy[s_prime] + reward_func[s, a, s_prime])
                        for s_prime in states])
                else:
                    # reward only depends on the current state and action
                    q_vals_new[s, a] = reward_func[s, a] + discount_factor * sum(
                        [trans_func[s, a, s_prime] * v_policy[s_prime] for s_prime in states])
                update = max(update, abs(q_vals[s, a] - q_vals_new[s, a]))
        q_vals = q_vals_new
        if update <= update_tolerance:
            return q_vals


def get_Q_values_opt_policy(discount_factor, trans_func, reward_func,
                            reward_dep_next_state=False, min_iterations=100):
    """
    Returns the Q-values for each state under the optimal policy as well
    as the optimal policy.

    Runs policy iteration. It starts from the policy that always picks
    action 0 and alternates two steps: policy evaluation (warm-started from
    the previous Q-values) and greedy policy improvement. It stops once the
    greedy policy equals the evaluated policy and at least ``min_iterations``
    evaluations have been done.

    Args:
        discount_factor: discount factor of MDP
        trans_func: transition function (dim.: num_states x num_actions x num_states)
        reward_func: reward function (dim.: num_states x num_actions, or
            num_states x num_actions x num_states if reward_dep_next_state)
        reward_dep_next_state (boolean, default: False): whether reward depends also on the next state
        min_iterations (int, default: 100): minimum number of policy
            evaluation steps before stopping

    Returns:
        q_vals: Q-values (dim.: num_states x num_actions) of the returned policy
        policy: greedy policy, an array holding one action index per state
    """
    num_states = len(trans_func)
    num_act = len(trans_func[0])
    q_vals = np.zeros((num_states, num_act))

    # Start from the always-action-0 policy, which is valid even when there
    # is only a single action.
    policy = np.zeros(num_states, dtype=int)
    it = 0

    while True:
        # Evaluate the current policy, then improve it greedily.
        q_vals = policy_evaluation(num_states, discount_factor,
                                   trans_func, reward_func, policy, q_vals,
                                   reward_dep_next_state=reward_dep_next_state)
        policy_new = np.argmax(q_vals, axis=1)
        it += 1
        # A stable policy means q_vals are the Q-values of policy_new itself.
        if np.array_equal(policy, policy_new) and it >= min_iterations:
            return q_vals, policy_new
        policy = policy_new



def get_map_effort_reward(effort_mean, output_lower_bound,
                          output_upper_bound, input_lower_bound,
                          input_upper_bound):
    """
    Computes a mapping from effort responses to rewards.

    The integer input scale is mapped piecewise-linearly onto the output scale:
    - input_lower_bound maps to output_lower_bound.
    - effort_mean maps to the midpoint of the output scale.
    - input_upper_bound maps to output_upper_bound.

    Integers below effort_mean are spaced evenly on the lower half, and
    integers above it on the upper half. If effort_mean is an integer equal
    to a bound, that bound ends up mapped to the midpoint.

    Args:
        effort_mean (float): mean of the (weighted) effort responses, to be mapped
                             to halfway between output_lower_bound and output_upper_bound
        output_lower_bound (float): lowest value on output scale
        output_upper_bound (float): highest value on output scale
        input_lower_bound (int): lowest value on input scale, to
                                 be mapped to output_lower_bound
        input_upper_bound (int): highest value on input scale,
                                 to be mapped to output_upper_bound

    Returns:
        dictionary: maps effort responses (integers) to output scale values (float)
    """
    map_to_rew = {}

    # We can already map the endpoints of the input scale to the output scale
    map_to_rew[input_lower_bound] = output_lower_bound
    map_to_rew[input_upper_bound] = output_upper_bound

    # The mean value on the output scale
    mean_output = (output_upper_bound - output_lower_bound) / 2 + output_lower_bound
    output_length_half = mean_output - output_lower_bound

    input_length_lower_half = effort_mean - input_lower_bound
    input_length_upper_half = input_upper_bound - effort_mean

    # Output step per input unit on each half. A half of zero length has no
    # interior integers to map, so its step is never used. The guard avoids
    # a ZeroDivisionError when effort_mean coincides with a bound.
    inc_lower_half = (output_length_half / input_length_lower_half
                      if input_length_lower_half else 0.0)
    inc_upper_half = (output_length_half / input_length_upper_half
                      if input_length_upper_half else 0.0)

    # Compute output scale values for input scale below mean.
    lower_inputs = range(input_lower_bound + 1, int(np.ceil(effort_mean)))
    for idx, i in enumerate(lower_inputs, start=1):
        map_to_rew[i] = output_lower_bound + idx * inc_lower_half

    # Compute output scale values for input scale above mean.
    upper_inputs = range(input_upper_bound - 1, int(np.floor(effort_mean)), -1)
    for idx, i in enumerate(upper_inputs, start=1):
        map_to_rew[i] = output_upper_bound - idx * inc_upper_half

    # An integer effort_mean is itself a value on the input scale, so it is
    # mapped to mean_output. The key is stored as an int, like all others.
    if np.floor(effort_mean) == effort_mean:
        map_to_rew[int(effort_mean)] = mean_output

    return map_to_rew
#
def map_efforts_to_rewards(list_of_weigthed_sum, map_to_rewards):
    """
    Translate each weighted sum of effort and dropout responses into its reward.

    Args:
        list_of_weigthed_sum (iterable of int): weighted sums to translate
        map_to_rewards (dictionary): lookup table from weighted sum to reward value (float)

    Returns:
        list (float): the reward for every input value, in input order

    Raises:
        KeyError: if a weighted sum has no entry in map_to_rewards
    """
    return [map_to_rewards[weighted_sum] for weighted_sum in list_of_weigthed_sum]


def weighted_sum_of_reward__for_transitions(effort_weight):
    """
    Collapse the effort and dropout responses of the transition samples into
    one weighted reward per row.

    Steps:
    - Read "RL_trasition_samples.csv" from the working directory.
    - For every row compute
      int(effort * effort_weight + dropout_response * (1 - effort_weight)),
      where int() truncates toward zero.
    - Replace the 'effort' and 'dropout_response' columns with a
      'weighted_reward' column.
    - Save the result to "RL_trasition_weighted_reward.csv".

    Args:
        effort_weight: weight given to the effort response; the dropout
            response is weighted by (1 - effort_weight)

    Returns:
        transitions_df: the transitions with the 'weighted_reward' column
        weighted_mean (float): mean of the weighted rewards, rounded to 2 decimals
        weighted_min (int): lower bound (minimum) of the weighted rewards
        weighted_max (int): upper bound (maximum) of the weighted rewards
    """
    dropout_weight = 1 - effort_weight
    samples = pd.read_csv("RL_trasition_samples.csv")
    weighted_reward = [
        int(effort * effort_weight + dropout * dropout_weight)
        for effort, dropout in zip(samples['effort'], samples['dropout_response'])
    ]

    # Every weighted reward is already an int, so the bounds need no rounding.
    weighted_max = max(weighted_reward)
    weighted_min = min(weighted_reward)
    weighted_mean = round(sum(weighted_reward) / len(weighted_reward), 2)

    transitions_df = (
        samples
        .assign(weighted_reward=weighted_reward)
        .drop(columns=['effort', 'dropout_response'])
    )
    transitions_df.to_csv("RL_trasition_weighted_reward.csv", index=False)
    return transitions_df, weighted_mean, weighted_min, weighted_max


def line_up_with_index(df):
    """
    Line up the activity-cluster index so that it starts with 0 instead of 1.

    Subtracts 1 from the third column of ``df`` (positional column 2) and
    stores the result in a new column "cluster_new_index". ``df`` is
    modified in place and also returned.

    Args:
        df: the data frame; its third column is expected to hold the 1-based
            activity cluster (NOTE(review): selected by position, not by name)
    Returns:
        df: the same data frame with the added "cluster_new_index" column
    """
    # One vectorized column operation instead of a per-row .iloc lookup.
    # .to_numpy() keeps the assignment positional, like a row-by-row loop.
    df["cluster_new_index"] = df.iloc[:, 2].to_numpy() - 1
    return df

def get_opt_policy_without_repeat(q_values):
    """
    Derive a greedy policy from q_values while trying to avoid assigning the
    same action to more than one state.

    States (rows) are processed in order:
    - If a state's greedy action occurs more than once in the working policy
      list (greedy actions, where already-processed states hold the action
      chosen for them), the state does one of two things. It takes its
      second-best action when an earlier state was already assigned that
      greedy action. Otherwise it picks at random between its best and
      second-best action.
    - If the greedy action occurs only once, it is kept.

    The resulting policy is printed.

    NOTE(review): uniqueness is not guaranteed. The second-best action is
    used without checking whether it is already assigned. The result is also
    non-deterministic because of random.choice. With tied Q-values, the
    argsort-based second-best index may coincide with the argmax index.

    Args:
        q_values: the q_values table, a 2-D numpy array indexed as
            q_values[state, action] with at least two actions (needed for
            the second-best lookup)
    Return:
        opt_policy_remove_repeat: the action chosen for each state
        q_values_max: the Q-value of the chosen action for each state
    """
    # Per-state maxima; every entry is overwritten at the end with the Q-value
    # of the action actually chosen, so this mainly sizes the list.
    q_values_max = [np.max(q_values[s]) for s in range(len(q_values))]
    # Optimal (greedy) action per state; updated in place below as choices are made
    opt_policy = [np.argmax(s) for s in q_values]
    opt_policy_remove_repeat = []
    # Second-best action per state: second-to-last index of the ascending argsort
    sec_q_values_max = np.argsort(q_values, axis=1)[:, -2]
    for i in range(len(opt_policy)):
        # the greedy action occurs more than once in the working policy list
        if opt_policy.count(opt_policy[i]) > 1:
            # an earlier state already took it: fall back to the second-best action
            if opt_policy_remove_repeat.count(opt_policy[i]) > 0:
                opt_policy_remove_repeat.append(sec_q_values_max[i])
                opt_policy[i] = sec_q_values_max[i]
            else:
                # not taken yet: pick best or second-best at random
                choices = [opt_policy[i], sec_q_values_max[i]]
                choice = random.choice(choices)
                opt_policy[i] = choice
                opt_policy_remove_repeat.append(choice)
        else:
            # unique greedy action: keep it
            opt_policy_remove_repeat.append(opt_policy[i])
    print("Optimal policy remove repeatition:", opt_policy_remove_repeat, "\n")
    # Report the Q-value of the action actually chosen for each state
    for i in range(len(opt_policy_remove_repeat)):
        q_values_max[i] = q_values[i, opt_policy_remove_repeat[i]]
    return opt_policy_remove_repeat, q_values_max
